import sys
import os
import time
import math
import sklearn
import json
import pandas as pd
import features_topk as ft
from sklearn.inspection import permutation_importance
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, roc_auc_score, roc_curve, f1_score


######################################################################
#                        READ INPUT				     #
######################################################################
def read_value(string):
    """Convert a raw option value from the options file into a Python value.

    A value wrapped in single quotes is returned without the quotes.
    Otherwise the value is parsed as an ``int``, then as a ``float``; if
    neither parse succeeds, the original string is returned unchanged.

    Args:
        string: Raw value text (without the line terminator).

    Returns:
        The unquoted string, an int, a float, or the original string.
    """
    # An empty value ("KEY<TAB>") used to raise IndexError on string[0].
    if not string:
        return string
    if string[0] == "'" and string[-1] == "'":
        return string[1:-1]
    try:
        return int(string)
    except ValueError:
        pass
    try:
        return float(string)
    except ValueError:
        return string

def load_options(fname):
    """Parse a tab-separated options file into a dict.

    Each non-comment line of the form ``KEY<TAB>VALUE`` becomes
    ``d[KEY] = read_value(VALUE)``. Lines starting with ``#`` and lines
    without a tab are ignored. Later duplicate keys overwrite earlier ones.

    Args:
        fname: Path of the options file.

    Returns:
        dict mapping option names to parsed values.
    """
    d_options = {}
    with open(fname, "r") as f:
        for line in f:
            if line.startswith("#") or "\t" not in line:
                continue
            # Strip only the line terminator. The old line[:-1] chopped the
            # last character of the value when the final line had no newline.
            line = line.rstrip("\r\n")
            li = line.split("\t")
            d_options[li[0]] = read_value(li[1])
    print(d_options)
    return d_options

def get_list(d):
    """Collect malware and benign cell-file paths under ``d["DATA_LOC"]``.

    File-number ranges are start-inclusive and end-exclusive. A file whose
    name contains ``-`` (``X-Y.cell``) is always treated as malicious.
    Otherwise the numeric stem decides: ``[mstart, mend)`` is malicious and
    ``[bstart, bend)`` is benign. Benign files are collected only when
    ``MULTICLASS`` is 0. Files with a non-numeric stem are skipped.

    Args:
        d: Options dict. Reads MAL_START/MAL_END (or MAL_TOTAL),
           BEN_INSTSTART/BEN_INSTEND (or BEN_INSTNUM),
           MAL_INSTSTART/MAL_INSTEND (or MAL_INSTNUM), DATA_LOC, and the
           optional HOSTFTS, FOLD_TOTAL (default 10) and MULTICLASS
           (default 0) keys.

    Returns:
        [malfnames, benfnames, foldtotal, multi, mend, miend, totalben, hostfts]

    Raises:
        Exception: if DATA_LOC is missing from the options.
    """
    if "MAL_START" in d:
        mstart, mend = d["MAL_START"], d["MAL_END"]
    else:
        mstart, mend = 0, d["MAL_TOTAL"]

    if "BEN_INSTSTART" in d:
        bstart, bend = d["BEN_INSTSTART"], d["BEN_INSTEND"]
    else:
        # Benign files are numbered directly after the malware files.
        bstart, bend = mend, mend + d["BEN_INSTNUM"]

    miend = d["MAL_INSTEND"] if "MAL_INSTSTART" in d else d["MAL_INSTNUM"]

    if "DATA_LOC" not in d:
        raise Exception("Input folder path not found in options.")

    # str() so that a numeric option value does not crash on .lower().
    hostfts = "HOSTFTS" in d and str(d["HOSTFTS"]).lower() == "true"
    print("=========Hostfts", hostfts)
    foldtotal = d.get("FOLD_TOTAL", 10)
    # 0: binary classification, 1: multiclass, 2: multiclass multi-label
    multi = d.get("MULTICLASS", 0)

    totalmal = (mend - mstart) * (miend + 1)  # X.cell plus X-Y.cell for 0 <= Y < miend
    totalben = bend - bstart  # Y.cell for bstart <= Y < bend
    malfnames = []
    benfnames = []
    for dp, _dnames, fnames in os.walk(d["DATA_LOC"]):
        for fname in fnames:
            # os.path.join supplies the separator that dp + fname omitted
            # whenever DATA_LOC (or a subdirectory) had no trailing "/".
            # os.walk yields each path once, so no dedupe is needed.
            fpath = os.path.join(dp, fname)
            if "-" in fname:
                malfnames.append(fpath)
                print(fname + "----Malicious")
                continue
            try:
                fnum = int(fname.split(".")[0])
            except ValueError:
                print("Skipping non-numeric file name: " + fname)
                continue
            if mstart <= fnum < mend:
                malfnames.append(fpath)
                print(fname + "----Malicious")
            elif bstart <= fnum < bend and multi == 0:
                # Take benign data only for binary classification
                benfnames.append(fpath)
                print(fname + "----Benign")

    print("Malstart:%d\nMalend: %d\nBenstart:%d\nBenend:%d\n" % (mstart, mend - 1, bstart, bend - 1))
    readm = len(malfnames)
    readb = len(benfnames)
    if readm != totalmal:
        print("Malware files read (%d) and options spec (%d) mismatch" % (readm, totalmal))
    if readb != totalben:
        print("Benign files read (%d) and options spec (%d) mismatch" % (readb, totalben))
    print("Read:\n Malware files: %d\n Benign files:%d\n CV Folds:%d\n" % (readm, readb, foldtotal))
    return [malfnames, benfnames, foldtotal, multi, mend, miend, totalben, hostfts]


# Check that the AVCLASS label file and VT reports required for multi-label classification are available
def check_dependencies(d):
    """Check that the inputs needed for multi-label classification exist.

    Requires ``MULTICLASS == 1`` plus an AVCLASS label file and a VT report
    file. Their paths come from ``AVCLASS_FILE_LOC`` / ``VT_REPORTS_LOC``, or
    default to ``<cwd>/avclass/AVCLASS.classes`` and
    ``<cwd>/avclass/reports.avclass``. Both paths must exist.

    Args:
        d: Options dict.

    Returns:
        True if all requirements are met, otherwise False (after printing why).
    """
    if "MULTICLASS" not in d:
        print("Can't reach here")
        return False
    if d["MULTICLASS"] != 1:
        print("Set 'MULTICLASS = 1' in options file to run multi label classifier\n (Mode=0 : Binary classification, Mode=1 : Multi label classification)")
        return False

    # os.path.join: the old getcwd()+"avclass/..." lacked the separator.
    location = d.get("AVCLASS_FILE_LOC", os.path.join(os.getcwd(), "avclass", "AVCLASS.classes"))
    if not os.path.exists(location):
        print("AVCLASS label file required for multilabel classification! : ", location)
        return False

    reportloc = d.get("VT_REPORTS_LOC", os.path.join(os.getcwd(), "avclass", "reports.avclass"))
    if not os.path.exists(reportloc):
        print("VT detail reports required for AVCLASS labels of malware samples!: ", reportloc)
        return False

    return True

def getfiles(d):
    """Return ``[avclass_label_file, vt_reports_file]`` paths from the options.

    Falls back to ``<cwd>/avclass/AVCLASS.classes`` and
    ``<cwd>/avclass/reports.avclass`` when ``AVCLASS_FILE_LOC`` /
    ``VT_REPORTS_LOC`` are absent.
    """
    # os.path.join: the old getcwd()+"avclass/..." lacked the separator.
    location = d.get("AVCLASS_FILE_LOC", os.path.join(os.getcwd(), "avclass", "AVCLASS.classes"))
    reportloc = d.get("VT_REPORTS_LOC", os.path.join(os.getcwd(), "avclass", "reports.avclass"))
    return [location, reportloc]

def get_sha_md5map(reports="avclass/reports.avclass"):
    """Build md5 <-> sha256 lookups from a JSON-lines VirusTotal report file.

    Each non-blank line must be a JSON object. Hashes are read from
    ``["data"]["attributes"]["md5"]`` and ``["sha256"]``. For each hash, the
    first mapping seen wins. Reports with attributes but no md5 or no sha256
    are counted as skipped.

    Args:
        reports: Path of the report file.

    Returns:
        [md5_sha, sha_md5]: dicts md5 -> sha256 and sha256 -> md5.
    """
    md5_sha = {}
    sha_md5 = {}
    skipped = 0
    with open(reports, "r") as f:
        for line in f:
            # json.loads("") raises, so tolerate blank lines (e.g. a trailing one).
            if not line.strip():
                continue
            datadt = json.loads(line)
            if "data" not in datadt or "attributes" not in datadt["data"]:
                continue
            attdt = datadt["data"]["attributes"]
            if "md5" not in attdt or "sha256" not in attdt:
                skipped += 1
                print("No MD5/SHA hash in report , skipping....\n")
                continue
            md5_sha.setdefault(attdt["md5"], attdt["sha256"])
            sha_md5.setdefault(attdt["sha256"], attdt["md5"])

    print("Not found: ", skipped)
    print("# of md5-labels read: ", len(sha_md5))
    return [md5_sha, sha_md5]

# return: sha256 malware hash: [class labels]
def read_multilabels(avclassfile, reports):
    """Read AVCLASS multi-label output keyed by sha256.

    Each line has the form ``md5 [number] tok,tok,...``. Commas are treated
    as whitespace, an optional numeric second field is dropped, and every
    other remaining token is kept as a label (presumably the tokens alternate
    label, count). The md5 is mapped to sha256 with the VT reports; samples
    without a report are skipped. The first entry per sha256 wins.

    Args:
        avclassfile: Path of the AVCLASS labels file.
        reports: Path of the VT JSON-lines report file.

    Returns:
        dict sha256 -> list of label strings.
    """
    md5_sha = get_sha_md5map(reports)[0]  # needs md5 -> sha map
    classdt = {}

    with open(avclassfile, "r") as ff:
        for raw in ff:
            # split() rather than split(" "): repeated spaces (e.g. ", ") no
            # longer produce empty tokens that break the label/count pairing.
            fields = raw.replace(",", " ").split()
            if len(fields) < 2:
                # Blank or md5-only lines used to raise IndexError.
                continue
            md5 = fields[0]
            labellst = fields[2:] if fields[1].isnumeric() else fields[1:]
            labels = labellst[0::2]
            print("Labels: ", labels)
            if md5 in md5_sha:
                classdt.setdefault(md5_sha[md5], labels)
            else:
                print("VT report not found!! Skipping sample: ", md5, labels)

    print("# Samples with AVCLASS labels: ", len(classdt))
    return classdt

# Read map index file in expdata/new_map_malinst_source
# return: {mapi: [labels, shahash]} (multi-label mode), or {shahash: mapi} when multiclass=True
# NOTE: Needs "MAPFILE" location in options file
def sha_mapi(d, classdt, multiclass=False):
    """Read the map file ``d["MAPFILE"]`` that links sha256 hashes to malware indices.

    Each line is ``sha256<TAB>mapi``; blank or malformed lines are skipped.

    * ``multiclass=False`` (multi-label case): returns ``{mapi: [labels, sha]}``.
      Labels come from ``classdt[sha]``. If the sha is not a key, the first
      ``classdt`` key that contains it as a substring is used, and that full
      key is stored. Unmatched samples get ``["unknown"]``.
    * ``multiclass=True``: returns ``{sha: mapi}``.

    In both modes the first entry for a key wins.

    Args:
        d: Options dict. Reads MAPFILE and MAL_TOTAL.
        classdt: dict sha256 -> labels.
        multiclass: Selects the return shape described above.

    Returns:
        The mapping described above, or None when MAPFILE is not set.
        Asserts that the number of entries equals ``d["MAL_TOTAL"]``.
    """
    if "MAPFILE" not in d:
        print("Aborting! Map file not available! Run cell extraction script.")
        return None

    sha_label_map = {}
    uniq_labels = []
    with open(d["MAPFILE"], "r") as f:
        for raw in f:
            fields = raw.rstrip().split("\t")
            if len(fields) < 2:
                continue
            shahash = fields[0]
            mapi = int(fields[1])

            if multiclass:
                # For multiclass, just return the sha -> mapi mapping.
                sha_label_map.setdefault(shahash, mapi)
                continue

            labels = ["unknown"]
            if shahash in classdt:
                labels = classdt[shahash]  # Get multilabels
            else:
                for sha, lab in classdt.items():
                    if shahash in sha:  # prefix/partial hash match
                        labels = lab
                        shahash = sha
                        break
            # The map is keyed by mapi. The old check tested shahash here, so
            # it never matched and duplicate indices overwrote earlier rows.
            if mapi not in sha_label_map:
                sha_label_map[mapi] = [labels, shahash]
                for lab in labels:
                    if lab == "unknown":
                        print(shahash)
                    uniq_labels.append(lab)

    print("Total Malware-AVCLASS mapping: ", len(sha_label_map))
    print("Labels in dataset: ", set(uniq_labels))
    assert len(sha_label_map) == d["MAL_TOTAL"]
    return sha_label_map

# Return: {md5: malware family} (SINGLETON families are skipped)
def get_family(sha_md5, malfamily="avclass/Malware.family"):
    """Read the AVCLASS family file and return ``{md5: family}``.

    Each line is ``md5<TAB>family``. Families containing ``SINGLETON`` are
    skipped, blank or malformed lines are ignored, and the first family seen
    for an md5 wins.

    Args:
        sha_md5: Unused; kept for compatibility with existing callers.
        malfamily: Path of the family file.

    Returns:
        dict md5 -> family name.
    """
    md5_fam = {}
    with open(malfamily, "r") as f:
        for raw in f:
            fields = raw.rstrip().split("\t")
            if len(fields) < 2:
                continue  # blank/malformed line used to raise IndexError
            md5, fam = fields[0], fields[1]
            if "SINGLETON" in fam:  # Just take known families for now
                continue
            print(md5, fam)
            md5_fam.setdefault(md5, fam)

    print("Unique md5-family mapping: ", len(md5_fam))
    return md5_fam

# Get AVCLASS Malware families for multiclass labelling:
# location: avclass/Malware.family
def get_mal_families(d):
    """Map each malware index (as a string) to its AVCLASS family name.

    Combines three sources: the VT reports (sha256 -> md5), the map file
    (sha256 -> malware index, via sha_mapi), and the AVCLASS family file
    (md5 -> family). Malware whose md5 has no non-SINGLETON family gets
    ``"mal_<index>"``.

    Args:
        d: Options dict (needs MAPFILE and MAL_TOTAL for sha_mapi).

    Returns:
        dict str(malware index) -> family name; {} if the map file is missing.
    """
    # 1. sha -> md5 mapping from the VT reports
    sha_md5 = get_sha_md5map()[1]
    # 2. sha -> malware index. multiclass=True returns {sha: mapi}, the shape
    #    iterated below; the default (multi-label) mode returns
    #    {mapi: [labels, sha]}, so no hash ever matched.
    sha_mal_i = sha_mapi(d, sha_md5, multiclass=True)
    if sha_mal_i is None:
        return {}
    # 3. md5 -> family name
    md5_fam = get_family(sha_md5)
    # 4. malware index -> family
    mal_i_fam = {}
    for sha, mal_i in sha_mal_i.items():
        if sha not in sha_md5:
            print("Match for hash NotFound! : ", sha)
            continue
        md5 = sha_md5[sha]
        if md5 in md5_fam:
            fam = md5_fam[md5]
        else:
            fam = "mal_" + str(mal_i)
            print("SINGLETON replaced with malware index: ", fam)
        # Check and store the same (str) key; the old check used the int.
        mal_i_fam.setdefault(str(mal_i), fam)

    print("Sha-md5 and Sha-malwareindex: ", len(sha_md5), len(sha_mal_i))
    print("Malware to family name mapping: ", len(mal_i_fam))
    return mal_i_fam

# Count malware binary-family distribution
def malware_distribution(labeldt):
    """Print how many distinct malware binaries belong to each family.

    Only names without ``-`` (the ``X.cell`` file, one per binary) are
    counted. ``labeldt`` maps file name -> ``[family, malware_index]``.
    Prints ``(count, family)`` tuples in ascending order and returns None.
    """
    per_family = {}
    for name, value in labeldt.items():
        if "-" in name:
            continue
        family = value[0]
        per_family[family] = per_family.get(family, 0) + 1
    ranking = sorted((n, family) for family, n in per_family.items())
    print("Family-malware distriubtion: ", ranking)
######################################################################
#                        LABELLING				     #
######################################################################
# MULTICLASS: 0
def label_binary(malfnames, benfnames):
    """Label files for binary classification: malware -> 1, benign -> 0.

    Malware paths are labelled first, so a path appearing in both lists
    keeps label 1. Each path is labelled once.

    Returns:
        dict file path -> label.
    """
    print("Binary labelling.....")
    labeldt = {}
    for names, tag in ((malfnames, 1), (benfnames, 0)):
        for name in names:
            if name in labeldt:
                continue
            labeldt[name] = tag
            print(name, str(tag))
    print("Labelled dataset: ", len(labeldt))
    return labeldt

# MULTICLASS: 1
def label_multiclass(d, malfnames, benfnames, mal_family=False):
    """Label malware files by binary index, or by family when ``mal_family``.

    For ``.../X.cell`` or ``.../X-Y.cell`` the binary index is ``"X"``.
    Without ``mal_family`` each path is labelled ``"X"``. With it, each path
    is labelled ``[family, "X"]`` using get_mal_families, falling back to
    ``["mal_X", "X"]``. ``benfnames`` is not used.

    Returns:
        dict file path -> label.
    """
    print("Multiclass labelling.......Malware family?:", mal_family)
    fam_lookup = get_mal_families(d) if mal_family else None
    labeldt = {}
    binaries = set()

    for path in malfnames:
        stem = path.split("/")[-1].split(".")[0]
        binary = stem.split("-")[0]  # X-Y.cell and X.cell both -> X
        if path not in labeldt:
            if not mal_family:
                print("Binary Multiclass Labelling: ", path, binary)
                labeldt[path] = binary
            elif binary in fam_lookup:
                fam = fam_lookup[binary]
                print("Family Multiclass Labelling: ", path, fam)
                labeldt[path] = [fam, binary]  # family name, malware binary index
            else:
                print("Don't expect to come here. All malware indices must have mapping")
                labeldt[path] = ["mal_" + binary, binary]
        binaries.add(binary)

    print("Labelled dataset: ", len(labeldt))
    print("Unique malware noted: ", len(binaries))
    if mal_family:
        malware_distribution(labeldt)
    return labeldt

def transform_labels(avlabels):
    """Fit a MultiLabelBinarizer on every AVCLASS label in ``avlabels``.

    ``avlabels`` maps malware index -> ``[labels, sha]``; only ``labels`` is
    used.

    Returns:
        [mlb, mlb.classes_]
    """
    distinct = set()
    for entry in avlabels.values():
        distinct.update(entry[0])
    uniqlabels = list(distinct)
    print(uniqlabels)
    mlb = MultiLabelBinarizer()
    mlb.fit([uniqlabels])
    print("MULTI LABEL CLASSES: ", mlb.classes_)
    return [mlb, mlb.classes_]

def generate_multilabel(elements, mlb, classorder):
    """Binarize one sample's labels with a fitted MultiLabelBinarizer.

    ``elements`` is the sample's label list and ``classorder`` is unused.
    Returns whatever ``mlb.transform([elements])`` produces.
    """
    encoded = mlb.transform([elements])
    print(elements, encoded)
    print(type(encoded))
    return encoded

# MULTICLASS: 1
def label_multiclass_multilabel(malfnames, benfnames, avlabels):
    """Give each malware file the multi-label indicator of its binary.

    ``avlabels`` maps int malware index X (from ``X.cell`` / ``X-Y.cell``) to
    ``[avclass labels, sha]``. Files whose index is missing are reported and
    left unlabelled. ``benfnames`` is not used.

    Returns:
        [labeldt, classorder, mlb]: path -> indicator, class order, binarizer.
    """
    print("Multiclass Multilabelling.......")
    mlb, classorder = transform_labels(avlabels)
    labeldt = {}

    for path in malfnames:
        stem = path.split("/")[-1].split(".")[0]
        index = int(stem.split("-")[0])  # X-Y.cell and X.cell both -> X
        print("Labelling: ", path, index)
        if path in labeldt:
            continue
        if index in avlabels:
            labeldt[path] = generate_multilabel(avlabels[index][0], mlb, classorder)
        else:
            print("Label for malware index: %d not available! Check mapping!" % index)

    print("Labelled dataset: ", len(labeldt))
    return [labeldt, classorder, mlb]

# Check cell directions for top-3 tor conns
def check_cells(torconn, connlen):
    """Return True if a connection has cells in both directions.

    A cell line containing ``"\\t1"`` counts as one direction and a line
    containing ``"\\t-1"`` as the other. ``connlen == 0`` short-circuits to
    False.
    """
    if connlen == 0:
        return False
    has_pos = any("\t1" in cell for cell in torconn)
    has_neg = any("\t-1" in cell for cell in torconn)
    return has_pos and has_neg

# Split malware and benign PCAPs/cell files by 70-30:Train-Test
# Eg: Benign: Test benign files: [157 - 1228] 30% go to test and remaining goes to train
# Eg: For each malware binary(D5): 3 or 4 PCAPs go to train (70%) and 2 or 1 PCAP go to test (30%)
def get_pcapsplit(labeldt, maltotal, bentotal, malinst, ds="D5"):
    """Split labelled cell files into train and test lists (about 70/30).

    Benign files (label 0) are split by file number: numbers in
    ``[maltotal, maltotal + floor(30% of bentotal)]`` (inclusive) go to test,
    the rest go to train.

    Malware files are grouped per binary ``X`` (``X.cell`` / ``X-Y.cell``).
    Each binary's first three files go to train. After that, a file goes to
    train only when the global ``switch`` flag is set and fewer than
    ``ceil(70% of maltotal * malinst)`` malware files are in train; each such
    train assignment clears the flag. A test assignment made while the binary
    has exactly three train files sets the flag again. Once a binary has four
    train files, its remaining files go to test.

    Args:
        labeldt: dict file path -> label (0 benign, otherwise malware).
        maltotal: Number of malware binaries (also the first benign file number).
        bentotal: Number of benign files.
        malinst: Files per malware binary, used for the 70% train budget.
        ds: Dataset tag; "D5" enables the 3:2 / 4:1 summary.

    Returns:
        (train, test) lists of file paths.
    """
    train = []
    test = []
    seen = dict()
    switch = 1  # Switch to control 3:2 and 4:1 pcap splits/binary
    train_ben = 0
    test_ben = 0
    mal_train = 0  # malware files assigned to train so far
    fourone = []
    threetwo = []
    ben_testindx_end = maltotal + math.floor((bentotal * 30) / 100)
    print(ben_testindx_end)
    mal_train_inst = math.ceil((maltotal * malinst * 70) / 100)

    for fpath, label in labeldt.items():
        name = fpath.split("/")[-1]
        # .strip(".cell") removed any of the characters ".cel" from both
        # ends; drop only the literal ".cell" suffix.
        bnum = name[:-len(".cell")] if name.endswith(".cell") else name

        if int(label) == 0:
            # Benign file split
            if maltotal <= int(bnum) <= ben_testindx_end:
                test.append(fpath)
                test_ben += 1
            else:
                train.append(fpath)
                train_ben += 1
            continue

        bnum = bnum.split("-")[0]  # X-Y.cell belongs to malware binary X
        print(fpath, bnum)
        if bnum not in seen:
            seen[bnum] = {"train": 1, "test": 0}
            train.append(fpath)
            mal_train += 1
            continue

        taken = seen[bnum]["train"]
        if taken < 3:
            train.append(fpath)
            seen[bnum]["train"] += 1
            mal_train += 1
        elif taken == 3 and switch and mal_train < mal_train_inst:
            # Budget counts malware only: len(train) also included benign
            # files, which exhausted it early when benign entries came first.
            train.append(fpath)
            seen[bnum]["train"] += 1
            mal_train += 1
            switch = 0
        else:
            if taken == 3:
                switch = 1
            test.append(fpath)
            seen[bnum]["test"] += 1

    print("PCAP splits/binary: Total: ", len(seen))

    # D5 will have some binaries with PCAPs split 3:2 and others 4:1 train:test ratio
    bins_3_2 = 0
    bins_4_1 = 0
    for k, v in seen.items():
        print("Malware binary: ", k, "PCAPs in Train: ", v['train'], ", Test:", v['test'])
        if ds == "D5":
            if v['train'] == 3:
                threetwo.append(k)
                bins_3_2 += 1
            else:
                fourone.append(k)
                bins_4_1 += 1
    if ds == "D5":
        print("(D5) Binaries with 3:2 split => ", bins_3_2, "\nBinaries with 4:1 split => ", bins_4_1)
    print("Binaries with 4:1 -", fourone)
    print("Binaries with 3:2 -", threetwo)
    print("Benign PCAPs split: TRAIN: ", train_ben, ", TEST: ", test_ben)
    print("Total PCAPs in TRAIN (malware+benign): ", len(train), "\nTotal PCAPs in TEST (malware+benign): ", len(test))
    return train, test


######################################################################
#                        FEATURE EXTRACTION			     #
######################################################################
def get_topk_conn_cells(data, topk=3, cut=False):
    """Group cell lines by Tor connection (1-3) and keep the bidirectional ones.

    Lines containing ``HOST_FTS`` are ignored. Each remaining line is split
    on ``#``, and the part after the first ``#`` is stored under connection
    1, 2 or 3. The connection is chosen by whether the line contains ``"1#"``,
    ``"2#"`` (only when ``topk >= 2``) or ``"3#"`` (only when ``topk == 3``),
    checked in that order. With ``cut=True``, cells whose timestamp (the
    first tab field after the ``#``) exceeds 360.0 are dropped.

    Selection, using check_cells to test for traffic in both directions:
      * connections 1, 2 and 3 all non-empty -> every bidirectional one;
      * connection 2 empty, 1 non-empty and bidirectional -> [conn1];
      * connection 2 non-empty, 3 empty -> the bidirectional ones among 1 and 2;
      * anything else -> [].

    Returns:
        list of cell-line lists, one per selected connection.
    """
    buckets = ([], [], [])
    for line in data:
        if "HOST_FTS" in line:
            continue
        pieces = line.rstrip().split("#")
        if cut and float(pieces[1].split("\t")[0]) > 360.0:
            continue  # skip cells after 6 minutes
        if "1#" in line:
            buckets[0].append(pieces[1])
        elif "2#" in line and topk >= 2:
            buckets[1].append(pieces[1])
        elif "3#" in line and topk == 3:
            buckets[2].append(pieces[1])

    sizes = [len(bucket) for bucket in buckets]
    bidir = [check_cells(bucket, n) for bucket, n in zip(buckets, sizes)]
    print("Top 3 Tor connection cells noted: ", *sizes)
    print("Bidirectional cells: t1:", bidir[0], "t2:", bidir[1], "t3:", bidir[2])

    n1, n2, n3 = sizes
    if n1 > 0 and n2 > 0 and n3 > 0:
        candidates = (0, 1, 2)
    elif n2 == 0 and n1 > 0 and bidir[0]:
        candidates = (0,)
    elif n2 > 0 and n3 == 0:
        candidates = (0, 1)
    else:
        candidates = ()
    topkconns = [buckets[i] for i in candidates if bidir[i]]

    print("Topkconns taken: ", len(topkconns))
    return topkconns

def extract_features(labeldt, multiclass, hostfts, top=3, checklist=None, classorder=None, trainmulti=True):
    """Build one feature row per selected Tor connection of every labelled file.

    For each file (restricted to ``checklist`` when it is non-empty), the
    connections returned by get_topk_conn_cells are turned into feature
    vectors with ``ft.TOTAL_FEATURES``. When ``hostfts`` is true, the
    host-level features of the whole file are appended to each row.

    Args:
        labeldt: dict file path -> label.
        multiclass: 1 selects multi-label mode; anything else builds a
            labelled DataFrame.
        hostfts: Append host-level features to each connection row.
        top: Number of top Tor connections requested.
        checklist: If non-empty, only these file paths are processed.
        classorder: Label classes; must be non-empty for multi-label training.
        trainmulti: In multi-label mode, ``list(label)`` is accumulated
            (training) when True, else ``label`` itself (testing).

    Returns:
        * multiclass == 1 and trainmulti: ``[feats, all_multilabels, filesread]``,
          or None if ``classorder`` is empty;
        * multiclass == 1 and not trainmulti: DataFrame whose last column is
          renamed ``'binary'``;
        * otherwise: DataFrame whose last (label) column is ``'target'``.
    """
    checklist = [] if checklist is None else checklist
    classorder = [] if classorder is None else classorder
    wanted = set(checklist)  # O(1) membership instead of scanning a list

    feats = []  # All features
    totalcols = []
    filesread = []
    all_multilabels = []
    skipped = 0
    print("Extracting features for Classification Mode: ", multiclass)
    print("Using Host features for training?: ", hostfts)

    for fpath, label in labeldt.items():
        print("*", fpath, label)
        if wanted and fpath not in wanted:
            skipped += 1
            continue
        # Close the handle; the old bare open() leaked one per file.
        with open(fpath) as fh:
            data = fh.readlines()
        filesread.append(fpath.split("/")[-1])

        # Extract TopK=3 highly active Tor connections
        print("Requested: ", top)
        topkconns = get_topk_conn_cells(data, top)
        print("Taking Top: ", len(topkconns))
        if not topkconns:
            print("No Tor connections in this cell file")
            continue

        # Extract features for TopK+Hostfts optionally
        for tcp_dump in topkconns:
            print("Total cells in Tor connection: ", len(tcp_dump))
            fts = ft.TOTAL_FEATURES(tcp_dump, False)
            print("CONNECTION-LEVEL: ", fts)

            if hostfts:
                # Extract only host fts here and add to conn-level fts
                Hfts = ft.TOTAL_FEATURES(data, False, onlyhost=True)
                print("HOST-LEVEL: ", Hfts)
                assert len(Hfts) > 0
                fts += Hfts
                assert len(fts) == 215 or len(fts) == 40
            print("All Features: ", fts)
            print("Total features used: ", len(fts))
            totalcols.append(len(fts))

            if multiclass == 1:
                # Multi label classification
                if trainmulti:
                    all_multilabels += list(label)
                else:
                    all_multilabels.append(label)
                feats.append(fts)
            else:
                # Binary classification
                feats.append(fts + [label])

    print("Skipped files not in checklist: ", skipped)
    # max() of an empty list raised ValueError when no file yielded features.
    print("Total features: ", max(totalcols) if totalcols else 0)
    print("Total files for which features extracted: ", len(feats))
    featdf = pd.DataFrame(feats)

    if multiclass == 1:
        if list(classorder) == [] and trainmulti:
            print("Classorder needed to set labels in DF!")
            return None
        if not trainmulti:  # Return only features for testing
            labeldf = pd.DataFrame(all_multilabels)
            fullfeatdf = pd.concat([featdf, labeldf], axis=1)
            return _name_last_column(fullfeatdf, "binary")
        return [feats, all_multilabels, filesread]

    featdf = _name_last_column(featdf, "target")
    print("Binary classification dataframe Size and shape: ", featdf, featdf.shape)
    if multiclass == 0:
        malrows = featdf[featdf['target'] == 1]
        benrows = featdf[featdf['target'] == 0]
        print("Malrows: ", malrows.shape, " Benrows: ", benrows.shape)
    return featdf


def _name_last_column(df, name):
    """Rename the last column of ``df`` to ``name``; an empty frame becomes a frame with just that column."""
    if len(df.columns) == 0:
        return pd.DataFrame(columns=[name])
    df.columns = [*df.columns[:-1], name]
    return df

######################################################################
#                     PERFORMANCE EVALUATION			     #
######################################################################
def evaluate_model(y_test, y_pred, probs, multiclass=False):
    """Compute recall and precision of ``y_pred`` against ``y_test``.

    Uses micro averaging when ``multiclass`` is truthy, and sklearn's default
    averaging otherwise. ``probs`` is unused (ROC/AUC is disabled).

    Returns:
        {'recall': ..., 'precision': ...}
    """
    averaging = {"average": "micro"} if multiclass else {}
    results = {
        "recall": recall_score(y_test, y_pred, **averaging),
        "precision": precision_score(y_test, y_pred, **averaging),
    }
    print("Micro Recall: ", results["recall"])
    print("Micro Precision: ", results["precision"])
    return results


def sklearn_featimpo(model):
    """Print the score in ``model.feature_importances_`` for each feature index."""
    print("====SKLEARN Feature Importance====")
    for idx, score in enumerate(model.feature_importances_):
        print('Feature: %0d, Score: %.5f' % (idx, score))

def permute_fimp(model, X, y):
    """Print permutation feature importances, most important first.

    Scores are ``importances_mean`` from sklearn's permutation_importance
    with accuracy scoring. Each printed line pairs a score with the index of
    the feature (column) it belongs to.
    """
    print("==== Permutation Feature Importance =====")
    results = permutation_importance(model, X, y, scoring='accuracy')
    importance = results.importances_mean
    # The old in-place importance.sort() lost which score belonged to which
    # feature, so the printed indices were just 0..n-1. Keep the index.
    ranked = sorted(enumerate(importance), key=lambda pair: pair[1], reverse=True)
    for i, v in ranked:
        print('Feature: %0d, Score: %.5f' % (i, v))
    return

def write_output(ff, score, tp, fp, tn, fn, results):
    """Write one tab-separated result row to the open file ``ff`` and return it.

    Row layout: score, tp, fp, tn, fn, precision, recall, then a newline.
    """
    prec = str(results['precision'])
    rec = str(results['recall'])
    counts = [str(v) for v in (score, tp, fp, tn, fn)]
    ff.write("\t".join(counts) + "\t")
    ff.write(prec + "\t" + rec + "\n")
    return ff

def output_avg(total, ag_res1, ag_res2, fimp1, fimp2, auto_cmatrix, bestmodel, perf, auc_score, ff):
    """Write the AutoGluon evaluation summary to ``ff``, then close ``ff``.

    Args:
        total: Unused.
        ag_res1: Results for the best model, written with ``str()``.
        ag_res2: Results for the stacked/weighted ensemble, written with ``str()``.
        fimp1, fimp2: Feature-importance tables or None; when given, their
            first 20 rows (``.head(20)``) are written.
        auto_cmatrix: Flattened confusion matrix ``[tn, fp, fn, tp]``.
        bestmodel: Name of the best model.
        perf: Performance summary, written with ``str()``.
        auc_score: AUC score on the test data.
        ff: Writable file object; closed before returning.
    """
    print(auto_cmatrix)
    ff.write("-----------------Autogluon----------------\n")
    ff.write("Best model confusion matrix: \n")
    [tn, fp, fn, tp] = auto_cmatrix
    negatives = fp + tn
    # FPR is undefined without negatives; avoid ZeroDivisionError.
    fpr = float(fp / negatives * 100) if negatives else float("nan")
    ff.write("TN: " + str(tn) + " FP: " + str(fp) + " FN: " + str(fn) + " TP: " + str(tp) + "\n")
    ff.write("::Model performance on test data::\n")
    ff.write("AUC Score: " + str(auc_score) + "\n")
    ff.write("FPR: " + str(fpr) + "\n")
    ff.write("Best model: " + str(bestmodel) + " \n")
    ff.write("Performance summary: " + str(perf) + " \n")
    ff.write(str(ag_res1))
    # `is not None`: `not df == None` raised "truth value of a DataFrame is
    # ambiguous" whenever a feature-importance table was actually passed.
    if fimp1 is not None:
        ff.write("*Ft impo*\n")
        ff.write(str(fimp1.head(20)) + "\n")
    ff.write("\n::Stacking & Weighted Ensembling of Models::\n")
    ff.write(str(ag_res2))
    if fimp2 is not None:
        ff.write("*Ft impo*\n")
        ff.write(str(fimp2.head(20)) + "\n")
    ff.write("--------------------------------------------\n")
    ff.close()
    return

def output_multilabel(mllabel_op, outfolder, maltotal, malinst, multiclassmode, hostfts):
    """Append multi-label training scores to a ``.score`` file in ``outfolder``.

    The file is ``MultilabelTraining_D<malinst+1>_[host].score``. The
    ``host`` part is added when ``hostfts`` is true.

    ``mllabel_op`` must hold one result per technique, in this order: Binary
    Relevance, Classifier Chains, Label Powerset. The first five items of
    each result are accuracy, Hamming loss, micro precision, micro recall and
    micro F1. ``maltotal`` and ``multiclassmode`` are unused.
    """
    techniques = ["Binary Relevance", "Classifier Chains", "Label Powerset"]
    assert len(techniques) == len(mllabel_op)
    fname = "MultilabelTraining_D" + str(malinst + 1) + "_"
    if hostfts:
        fname += "host"
    # os.path.join adds the separator when outfolder has no trailing "/".
    # The context manager closes (and flushes) the file, which was
    # previously left open.
    with open(os.path.join(outfolder, fname + ".score"), "a+") as ff:
        for technique, res in zip(techniques, mllabel_op):
            acc, hloss, mprec, mrecall, mf1 = (str(v) for v in res[:5])
            ff.write("Technique: %s, Model: Random Forest\n" % (technique))
            ff.write("Accuracy: %s; Hamming Loss: %s; Micro-Prec: %s; Micro-Recall: %s; Micro-F1: %s\n\n\n" % (acc, hloss, mprec, mrecall, mf1))
    return
